package com.taxi.service;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Batch order dispatch using the Hungarian algorithm (批量派单->匈牙利算法).
 *
 * <p>Given a square cost matrix whose rows are passengers and whose columns are drivers (each value
 * being the passenger-driver distance), finds an assignment of every passenger to a distinct driver
 * that minimises the total cost.
 *
 * <p>Outline:
 * <ol>
 *   <li>Reduce every row and then every column by its minimum, so each row and column has a zero.</li>
 *   <li>Select a maximum set of independent zeros (a maximum bipartite matching on the zero cells).</li>
 *   <li>If n independent zeros exist they form a zero-cost, hence optimal, assignment.</li>
 *   <li>Otherwise cover all zeros with the minimum number of lines, subtract the smallest uncovered
 *       value from every uncovered row, add it to every covered column, and repeat.</li>
 * </ol>
 *
 * @Author：Aniu
 * @Date：2024/2/7 3:57
 */
public class HungarianAlgorithmZF {

    /** Separator printed after each matrix by {@link #printArr(int[][])}. */
    private String arrows = "----------->";

    /**
     * Demo entry point: solves a sample 5x5 passenger/driver distance matrix and prints the
     * resulting assignment.
     */
    public static void main(String[] args) {
        HungarianAlgorithmZF hungarianAlgorithmZF = new HungarianAlgorithmZF();
        // Cost matrix: rows are passenger ids, columns are driver ids, values are distances.
        int[][] costs = {{5, 0, 2, 0, 2}, {2, 3, 7, 5, 4}, {0, 10, 5, 7, 5}, {9, 8, 0, 0, 4}, {0, 6, 3, 6, 2}};
        // solve() reduces its argument in place, so work on a copy and keep the real distances.
        int[][] assignments = hungarianAlgorithmZF.solve(hungarianAlgorithmZF.copyArr(costs));

        System.out.println("最佳指派方案：");
        for (int i = 0; i < assignments.length; i++) {
            Integer person = hungarianAlgorithmZF.findZeroIndex(assignments[i]).get(0);
            System.out.println("乘客" + i + "指派给司机" + person + "  距离为" + costs[i][person]);
        }
    }

    /**
     * Entry point for external callers.
     *
     * @param costs square cost matrix; {@code costs[i][j]} is the cost of assigning passenger i to
     *     driver j. It is copied first and therefore not modified.
     * @return one single-entry map per passenger, in row order, mapping the passenger (row) index to
     *     the assigned driver (column) index
     * @throws IllegalArgumentException if {@code costs} is not a square matrix
     */
    public List<HashMap<Integer, Integer>> out(int[][] costs) {
        int[][] assignments = this.solve(this.copyArr(costs));
        List<HashMap<Integer, Integer>> list = new ArrayList<>(assignments.length);
        for (int i = 0; i < assignments.length; i++) {
            HashMap<Integer, Integer> map = new HashMap<>();
            map.put(i, this.findZeroIndex(assignments[i]).get(0));
            list.add(map);
        }
        return list;
    }

    /**
     * Solves the assignment problem.
     *
     * <p>{@code costs} is reduced in place; pass a copy if the original values are still needed.
     *
     * @param costs square cost matrix; {@code costs[i][j]} is the cost (distance) of assigning task
     *     (passenger) i to worker (driver) j
     * @return a matrix of the same size in which every row contains exactly one {@code 0}, at the
     *     column assigned to that row; other cells are {@code -1} (an unselected zero of the reduced
     *     matrix) or {@code 1} (a non-zero cell)
     * @throws IllegalArgumentException if {@code costs} is not a square matrix
     */
    public int[][] solve(int[][] costs) {
        requireSquare(costs);
        if (costs.length == 0) {
            return new int[0][0];
        }
        while (true) {
            // Reduce the matrix and select a maximum set of independent zeros.
            int[][] zeroArr = findIndependentZeros(costs);
            // n independent zeros of a non-negative, equivalently reduced matrix = optimal assignment.
            if (checkHasNZero(zeroArr)) {
                return zeroArr;
            }
            // Fewer than n: cover the zeros with the minimum number of lines and create new zeros.
            // With integer costs each round strictly raises the total reduction, so this terminates.
            updateCost(costs, lineation(zeroArr));
        }
    }

    /** Rejects null, jagged or non-square cost matrices, which this method cannot solve. */
    private static void requireSquare(int[][] costs) {
        Objects.requireNonNull(costs, "costs");
        for (int i = 0; i < costs.length; i++) {
            if (costs[i] == null || costs[i].length != costs.length) {
                throw new IllegalArgumentException("costs must be a square matrix: row " + i
                        + " has length " + (costs[i] == null ? "null" : String.valueOf(costs[i].length))
                        + ", expected " + costs.length);
            }
        }
    }

    /**
     * Adjusts the cost matrix when fewer than n independent zeros exist. The smallest value not
     * covered by any line is subtracted from every uncovered row and added to every covered column.
     * Net effect: uncovered cells decrease (at least one becomes zero), doubly covered cells
     * increase, singly covered cells are unchanged, and no cell becomes negative.
     */
    private void updateCost(int[][] costs, List<List<Integer>> lineation) {
        int[] minPosition = findMinNumInOtherNum(costs, lineation);
        int minNum = costs[minPosition[0]][minPosition[1]];
        Set<Integer> coveredRows = new HashSet<>(lineation.get(0));
        Set<Integer> coveredCols = new HashSet<>(lineation.get(1));
        // Subtract the minimum from every uncovered row.
        for (int i = 0; i < costs.length; i++) {
            if (!coveredRows.contains(i)) {
                for (int j = 0; j < costs[i].length; j++) {
                    costs[i][j] -= minNum;
                }
            }
        }
        // Add it to every covered column.
        for (int j : coveredCols) {
            for (int i = 0; i < costs.length; i++) {
                costs[i][j] += minNum;
            }
        }
    }

    /**
     * Reduces {@code cost} in place (each row by its minimum, then each column by its minimum) and
     * selects a maximum set of independent zeros.
     *
     * @return marking matrix: {@code 0} = selected zero, {@code -1} = other zero, {@code 1} = non-zero
     */
    public int[][] findIndependentZeros(int[][] cost) {
        // Subtract each row's minimum from that row.
        for (int[] row : cost) {
            int minNum = findMinNum(row);
            for (int j = 0; j < row.length; j++) {
                row[j] -= minNum;
            }
        }
        // Subtract each column's minimum from that column.
        int colCount = cost.length == 0 ? 0 : cost[0].length;
        for (int j = 0; j < colCount; j++) {
            int minNum = findMinNum(getOneColumn(cost, j));
            for (int[] row : cost) {
                row[j] -= minNum;
            }
        }
        return selectZero(extractZero(cost));
    }

    /**
     * Finds the smallest element not covered by any line.
     *
     * @return {row, column} of that element
     * @throws IllegalStateException if every cell is covered
     */
    private int[] findMinNumInOtherNum(int[][] cost, List<List<Integer>> lineation) {
        Set<Integer> row = new HashSet<>(lineation.get(0));
        Set<Integer> col = new HashSet<>(lineation.get(1));
        int min = Integer.MAX_VALUE;
        int[] index = {-1, -1};
        for (int i = 0; i < cost.length; i++) {
            if (row.contains(i)) {
                continue;
            }
            for (int j = 0; j < cost[i].length; j++) {
                if (col.contains(j)) {
                    continue;
                }
                if (index[0] < 0 || cost[i][j] < min) {
                    min = cost[i][j];
                    index[0] = i;
                    index[1] = j;
                }
            }
        }
        if (index[0] < 0) {
            throw new IllegalStateException("every cell is covered by a line");
        }
        return index;
    }

    /**
     * Covers every zero with the minimum number of lines:
     * 1. tick every row without a selected zero;
     * 2. in ticked rows, tick the columns of their unselected zeros ({@code -1});
     * 3. in ticked columns, tick the rows of their selected zeros ({@code 0});
     * 4. repeat 2 and 3 until nothing new can be ticked;
     * 5. draw a horizontal line through every unticked row and a vertical line through every
     *    ticked column.
     * Because the selection is a maximum matching, the number of lines equals the number of
     * selected zeros.
     *
     * @return two lists: index 0 = rows with a horizontal line, index 1 = columns with a vertical line
     */
    private List<List<Integer>> lineation(int[][] zeroArr) {
        int rowCount = zeroArr.length;
        int colCount = rowCount == 0 ? 0 : zeroArr[0].length;
        boolean[] rowTicked = new boolean[rowCount];
        boolean[] colTicked = new boolean[colCount];
        Deque<Integer> pending = new ArrayDeque<>();
        for (int i = 0; i < rowCount; i++) {
            if (findZeroIndex(zeroArr[i]).isEmpty()) {
                rowTicked[i] = true;
                pending.add(i);
            }
        }
        // Each row and column is ticked at most once, so no duplicates accumulate.
        while (!pending.isEmpty()) {
            int[] rowArr = zeroArr[pending.poll()];
            for (int j = 0; j < rowArr.length; j++) {
                if (rowArr[j] != -1 || colTicked[j]) {
                    continue;
                }
                colTicked[j] = true;
                for (int i = 0; i < rowCount; i++) {
                    if (!rowTicked[i] && zeroArr[i][j] == 0) {
                        rowTicked[i] = true;
                        pending.add(i);
                    }
                }
            }
        }
        List<Integer> lineRows = new ArrayList<>();
        for (int i = 0; i < rowCount; i++) {
            if (!rowTicked[i]) {
                lineRows.add(i);
            }
        }
        List<Integer> lineCols = new ArrayList<>();
        for (int j = 0; j < colCount; j++) {
            if (colTicked[j]) {
                lineCols.add(j);
            }
        }
        List<List<Integer>> lineation = new ArrayList<>();
        lineation.add(lineRows);
        lineation.add(lineCols);
        return lineation;
    }

    /**
     * Selects a maximum set of independent zeros in {@code zeroArr} (a 0/1 matrix from
     * {@link #extractZero(int[][])}) via augmenting paths. Selected zeros stay {@code 0}; all other
     * zeros become {@code -1}. Modifies and returns {@code zeroArr}.
     */
    private int[][] selectZero(int[][] zeroArr) {
        int colCount = 0;
        for (int[] row : zeroArr) {
            colCount = Math.max(colCount, row.length);
        }
        // rowOfCol[j] = row currently matched to column j, or -1.
        int[] rowOfCol = new int[colCount];
        Arrays.fill(rowOfCol, -1);
        for (int i = 0; i < zeroArr.length; i++) {
            augment(zeroArr, i, new boolean[colCount], rowOfCol);
        }
        for (int i = 0; i < zeroArr.length; i++) {
            for (int j = 0; j < zeroArr[i].length; j++) {
                if (zeroArr[i][j] == 0 && rowOfCol[j] != i) {
                    zeroArr[i][j] = -1;
                }
            }
        }
        return zeroArr;
    }

    /** Tries to match {@code row} to a zero column, re-routing earlier matches if needed. */
    private boolean augment(int[][] zeroArr, int row, boolean[] visitedCols, int[] rowOfCol) {
        for (int j = 0; j < zeroArr[row].length; j++) {
            if (zeroArr[row][j] != 0 || visitedCols[j]) {
                continue;
            }
            visitedCols[j] = true;
            if (rowOfCol[j] == -1 || augment(zeroArr, rowOfCol[j], visitedCols, rowOfCol)) {
                rowOfCol[j] = row;
                return true;
            }
        }
        return false;
    }

    /** Returns true when the marking matrix holds as many selected zeros as it has rows. */
    private boolean checkHasNZero(int[][] zeroArr) {
        long sum = Arrays.stream(zeroArr).mapToLong(x -> Arrays.stream(x).filter(y -> y == 0).count()).sum();
        return sum == zeroArr.length;
    }

    /** Returns the indices of all {@code 0} entries of {@code arr}, in ascending order. */
    private List<Integer> findZeroIndex(int[] arr) {
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < arr.length; i++) {
            if (arr[i] == 0) {
                list.add(i);
            }
        }
        return list;
    }

    /** Returns the smallest value in a non-empty vector. */
    int findMinNum(int[] arr) {
        int min = arr[0];
        for (int i = 1; i < arr.length; i++) {
            if (arr[i] < min) {
                min = arr[i];
            }
        }
        return min;
    }

    /** Copies column {@code colIndex} of {@code arr} into a new array. */
    int[] getOneColumn(int[][] arr, int colIndex) {
        int[] col = new int[arr.length];
        for (int i = 0; i < arr.length; i++) {
            col[i] = arr[i][colIndex];
        }
        return col;
    }

    /** Maps a matrix to a 0/1 matrix: {@code 0} where the input is zero, {@code 1} elsewhere. */
    int[][] extractZero(int[][] arr) {
        int[][] zeros = new int[arr.length][];
        for (int i = 0; i < arr.length; i++) {
            zeros[i] = new int[arr[i].length];
            for (int j = 0; j < arr[i].length; j++) {
                zeros[i][j] = arr[i][j] == 0 ? 0 : 1;
            }
        }
        return zeros;
    }

    /** Prints a matrix row by row, followed by a separator line (debugging aid). */
    void printArr(int[][] arr) {
        for (int[] row : arr) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
        System.out.println(arrows);
    }

    /** Deep-copies a matrix (each row is cloned; empty and jagged matrices are supported). */
    int[][] copyArr(int[][] arr) {
        int[][] newArr = new int[arr.length][];
        for (int i = 0; i < arr.length; i++) {
            newArr[i] = arr[i].clone();
        }
        return newArr;
    }

}
